{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c64d3134-4dec-44ff-9f5e-55b033c48edb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import time\n",
    "import glob\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "01acdbde-f012-45db-ae94-beff97baa8f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = glob.glob('./train_set/*/*.png')\n",
    "np.random.shuffle(train_path)\n",
    "\n",
    "train_label_index = [\n",
    "    \"Speed limit (5km)\",         # 0\n",
    "    \"Speed limit (15km)\",        # 1\n",
    "    \"Speed limit (20km)\",        # 2\n",
    "    \"Speed limit (30km)\",        # 3\n",
    "    \"Speed limit (40km)\",        # 4\n",
    "    \"Speed limit (50km)\",        # 5\n",
    "    \"Speed limit (60km)\",        # 6\n",
    "    \"Speed limit (70km)\",        # 7\n",
    "    \"Speed limit (80km)\",        # 8\n",
    "    \"Speed limit (100km)\",       # 9\n",
    "    \"Speed limit (120km)\",       # 10\n",
    "    \"End of speed limit\",        # 11\n",
    "    \"End of speed limit (50km)\", # 12\n",
    "    \"End of speed limit (80km)\", # 13\n",
    "    \"Dont overtake from Left\",   # 14\n",
    "    \"No stopping\",               # 15\n",
    "    \"No Uturn\",                  # 16\n",
    "    \"No Car\",                    # 17\n",
    "    \"No horn\",                   # 18\n",
    "    \"No entry\",                  # 19\n",
    "    \"No passage\",                # 20\n",
    "    \"Dont Go Right\",             # 21\n",
    "    \"Dont Go Left or Right\",     # 22\n",
    "    \"Dont Go Left\",              # 23\n",
    "    \"Dont Go straight\",          # 24\n",
    "    \"Dont Go straight or Right\", # 25\n",
    "    \"Dont Go straight or left\",  # 26\n",
    "    \"Go right or straight\",      # 27\n",
    "    \"Go left or straight\",       # 28\n",
    "    \"Village\",                   # 29\n",
    "    \"Uturn\",                     # 30\n",
    "    \"ZigZag Curve\",              # 31\n",
    "    \"Bicycles crossing\",         # 32\n",
    "    \"Keep Right\",                # 33\n",
    "    \"Keep Left\",                 # 34\n",
    "    \"Roundabout mandatory\",      # 35\n",
    "    \"Watch out for cars\",        # 36\n",
    "    \"Slow down and give way\",    # 37\n",
    "    \"Continuous detours\",        # 38\n",
    "    \"Slow walking\",              # 39\n",
    "    \"Horn\",                      # 40\n",
    "    \"Uphill steep slope\",        # 41\n",
    "    \"Downhill steep slope\",      # 42\n",
    "    \"Under Construction\",        # 43\n",
    "    \"Heavy Vehicle Accidents\",   # 44\n",
    "    \"Parking inspection\",        # 45\n",
    "    \"Stop at intersection\",      # 46\n",
    "    \"Train Crossing\",            # 47\n",
    "    \"Fences\",                    # 48\n",
    "    \"Dangerous curve to the right\", # 49\n",
    "    \"Go Right\",                  # 50\n",
    "    \"Go Left or right\",          # 51\n",
    "    \"Dangerous curve to the left\", # 52\n",
    "    \"Go Left\",                   # 53\n",
    "    \"Go straight\",               # 54\n",
    "    \"Go straight or right\",      # 55\n",
    "    \"Children crossing\",         # 56\n",
    "    \"Care bicycles crossing\",    # 57\n",
    "    \"Danger Ahead\",              # 58\n",
    "    \"Traffic signals\",           # 59\n",
    "    \"Zebra Crossing\",            # 60\n",
    "    \"Road Divider\"               # 61\n",
    "]\n",
    "train_label_index = [x.lower() for x in train_label_index]\n",
    "train_label = [train_label_index.index(x.split('/')[-2].lower()) for x in train_path]\n",
    "\n",
    "test_path = glob.glob('./test_set/*.png')\n",
    "test_path.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "cd793915-cb4d-46a1-8724-376d52c321f8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9808"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "7c87b426-c326-4a3e-a3fe-3c69ccf2fa5b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./test_set/005d3487-92a5-48f3-9784-379f1217825d.png',\n",
       " './test_set/00a6a262-1643-443f-b8ed-7fc075ac3c1e.png',\n",
       " './test_set/00f7f66d-96f6-4c9c-b476-dced05e7211f.png',\n",
       " './test_set/01a4038f-159e-4a73-bc65-2d20b3cc2a2d.png',\n",
       " './test_set/01cf52dd-4235-403b-9696-00490042dcbd.png',\n",
       " './test_set/01d298fe-8d03-4ad4-9547-362c9c982438.png',\n",
       " './test_set/01ed9a77-80a1-40ba-a14f-434d8cae1fc4.png',\n",
       " './test_set/01f3eadb-adc3-4676-afa6-41558ffcee0f.png',\n",
       " './test_set/025d6acb-021f-4dd2-9fc2-6a6ea1c25d7f.png',\n",
       " './test_set/027848aa-c988-4fe8-a9d7-cedbe2671e2e.png']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_path[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "60781cb3-1943-4951-a5ca-4a940910e4a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "62"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_label_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "ab48904c-3f33-4299-a6ab-c44ee9a40670",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AverageMeter(object):\n",
    "    \"\"\"Computes and stores the average and current value\"\"\"\n",
    "    def __init__(self, name, fmt=':f'):\n",
    "        self.name = name\n",
    "        self.fmt = fmt\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        self.val = 0\n",
    "        self.avg = 0\n",
    "        self.sum = 0\n",
    "        self.count = 0\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        self.val = val\n",
    "        self.sum += val * n\n",
    "        self.count += n\n",
    "        self.avg = self.sum / self.count\n",
    "\n",
    "    def __str__(self):\n",
    "        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'\n",
    "        return fmtstr.format(**self.__dict__)\n",
    "\n",
    "class ProgressMeter(object):\n",
    "    def __init__(self, num_batches, *meters):\n",
    "        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)\n",
    "        self.meters = meters\n",
    "        self.prefix = \"\"\n",
    "\n",
    "\n",
    "    def pr2int(self, batch):\n",
    "        entries = [self.prefix + self.batch_fmtstr.format(batch)]\n",
    "        entries += [str(meter) for meter in self.meters]\n",
    "        print('\\t'.join(entries))\n",
    "\n",
    "    def _get_batch_fmtstr(self, num_batches):\n",
    "        num_digits = len(str(num_batches // 1))\n",
    "        fmt = '{:' + str(num_digits) + 'd}'\n",
    "        return '[' + fmt + '/' + fmt.format(num_batches) + ']'\n",
    "def validate(val_loader, model, criterion):\n",
    "    batch_time = AverageMeter('Time', ':6.3f')\n",
    "    losses = AverageMeter('Loss', ':.4e')\n",
    "    top1 = AverageMeter('Acc@1', ':6.2f')\n",
    "    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)\n",
    "\n",
    "    # switch to evaluate mode\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            loss = criterion(output, target)\n",
    "\n",
    "            # measure accuracy and record loss\n",
    "            acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100\n",
    "            losses.update(loss.item(), input.size(0))\n",
    "            top1.update(acc, input.size(0))\n",
    "            # measure elapsed time\n",
    "            batch_time.update(time.time() - end)\n",
    "            end = time.time()\n",
    "\n",
    "        # TODO: this should also be done with the ProgressMeter\n",
    "        print(' * Acc@1 {top1.avg:.3f}'\n",
    "              .format(top1=top1))\n",
    "        return top1\n",
    "\n",
    "def predict(test_loader, model, tta=10):\n",
    "    # switch to evaluate mode\n",
    "    model.eval()\n",
    "    \n",
    "    test_pred_tta = None\n",
    "    for _ in range(tta):\n",
    "        test_pred = []\n",
    "        with torch.no_grad():\n",
    "            end = time.time()\n",
    "            for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):\n",
    "                input = input.cuda()\n",
    "                target = target.cuda()\n",
    "\n",
    "                # compute output\n",
    "                output = model(input)\n",
    "                output = F.softmax(output, dim=1)\n",
    "                output = output.data.cpu().numpy()\n",
    "\n",
    "                test_pred.append(output)\n",
    "        test_pred = np.vstack(test_pred)\n",
    "    \n",
    "        if test_pred_tta is None:\n",
    "            test_pred_tta = test_pred\n",
    "        else:\n",
    "            test_pred_tta += test_pred\n",
    "    \n",
    "    return test_pred_tta\n",
    "\n",
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    batch_time = AverageMeter('Time', ':6.3f')\n",
    "    losses = AverageMeter('Loss', ':.4e')\n",
    "    top1 = AverageMeter('Acc@1', ':6.2f')\n",
    "    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)\n",
    "\n",
    "    # switch to train mode\n",
    "    model.train()\n",
    "\n",
    "    end = time.time()\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True)\n",
    "        target = target.cuda(non_blocking=True)\n",
    "\n",
    "        # compute output\n",
    "        output = model(input)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # measure accuracy and record loss\n",
    "        losses.update(loss.item(), input.size(0))\n",
    "\n",
    "        acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100\n",
    "        top1.update(acc, input.size(0))\n",
    "\n",
    "        # compute gradient and do SGD step\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # measure elapsed time\n",
    "        batch_time.update(time.time() - end)\n",
    "        end = time.time()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            progress.pr2int(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "dde84a60-5939-4b69-b0a9-bb43773186b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class XFDataset(Dataset):\n",
    "    def __init__(self, img_path, img_label, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.img_label = img_label\n",
    "        \n",
    "        if transform is not None:\n",
    "            self.transform = transform\n",
    "        else:\n",
    "            self.transform = None\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "        \n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        \n",
    "        return img, torch.from_numpy(np.array(self.img_label[index]))\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "aad4267f-b007-48a4-8e06-cf89df9b9f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "import timm\n",
    "model = timm.create_model('resnet18', pretrained=True, num_classes=62)\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "abc0c12c-5cdd-42fd-bfd6-9027431e46e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/torch/optim/lr_scheduler.py:143: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
      "  warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  0/466]\tTime  0.299 ( 0.299)\tLoss 4.0628e+00 (4.0628e+00)\tAcc@1   0.00 (  0.00)\n",
      "[100/466]\tTime  0.048 ( 0.046)\tLoss 1.5523e+00 (2.6152e+00)\tAcc@1  40.00 ( 29.60)\n",
      "[200/466]\tTime  0.043 ( 0.044)\tLoss 7.8758e-01 (1.9598e+00)\tAcc@1  75.00 ( 45.05)\n",
      "[300/466]\tTime  0.044 ( 0.044)\tLoss 4.9413e-01 (1.5744e+00)\tAcc@1  85.00 ( 55.40)\n",
      "[400/466]\tTime  0.039 ( 0.044)\tLoss 5.3654e-01 (1.3211e+00)\tAcc@1  75.00 ( 61.97)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2510832/2676678360.py:51: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a45a997b568437e9185d463c9a86d53",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 82.200\n",
      "Epoch:  1\n",
      "[  0/466]\tTime  0.303 ( 0.303)\tLoss 2.7170e-01 (2.7170e-01)\tAcc@1  90.00 ( 90.00)\n",
      "[100/466]\tTime  0.045 ( 0.046)\tLoss 4.4853e-01 (3.3013e-01)\tAcc@1  80.00 ( 89.70)\n",
      "[200/466]\tTime  0.043 ( 0.044)\tLoss 9.9414e-01 (2.9551e-01)\tAcc@1  80.00 ( 90.87)\n",
      "[300/466]\tTime  0.043 ( 0.044)\tLoss 3.0984e-01 (2.7762e-01)\tAcc@1  90.00 ( 91.66)\n",
      "[400/466]\tTime  0.043 ( 0.044)\tLoss 1.4144e-01 (2.6402e-01)\tAcc@1  95.00 ( 92.04)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df11100ba71c4e51bf1ac701d162fbfe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 91.000\n",
      "Epoch:  2\n",
      "[  0/466]\tTime  0.320 ( 0.320)\tLoss 4.3136e-01 (4.3136e-01)\tAcc@1  90.00 ( 90.00)\n",
      "[100/466]\tTime  0.049 ( 0.046)\tLoss 3.5571e-02 (2.3843e-01)\tAcc@1 100.00 ( 92.57)\n",
      "[200/466]\tTime  0.045 ( 0.045)\tLoss 4.5550e-01 (1.7889e-01)\tAcc@1  90.00 ( 94.40)\n",
      "[300/466]\tTime  0.045 ( 0.044)\tLoss 1.5154e-01 (1.5645e-01)\tAcc@1  95.00 ( 94.98)\n",
      "[400/466]\tTime  0.043 ( 0.044)\tLoss 2.3688e-02 (1.5547e-01)\tAcc@1 100.00 ( 94.96)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ece32752f2fb4caeb0c92de5db528d95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 96.200\n",
      "Epoch:  3\n",
      "[  0/466]\tTime  0.282 ( 0.282)\tLoss 2.0082e-01 (2.0082e-01)\tAcc@1  85.00 ( 85.00)\n",
      "[100/466]\tTime  0.043 ( 0.046)\tLoss 5.5439e-03 (8.4757e-02)\tAcc@1 100.00 ( 97.28)\n",
      "[200/466]\tTime  0.044 ( 0.044)\tLoss 2.0155e-01 (8.0886e-02)\tAcc@1  90.00 ( 97.31)\n",
      "[300/466]\tTime  0.042 ( 0.044)\tLoss 5.9428e-03 (8.5747e-02)\tAcc@1 100.00 ( 97.09)\n",
      "[400/466]\tTime  0.044 ( 0.044)\tLoss 3.1611e-01 (8.1407e-02)\tAcc@1  95.00 ( 97.24)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "706d01039d0948d99c7ce3b5426afdee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 98.200\n",
      "Epoch:  4\n",
      "[  0/466]\tTime  0.308 ( 0.308)\tLoss 3.0572e-01 (3.0572e-01)\tAcc@1  85.00 ( 85.00)\n",
      "[100/466]\tTime  0.045 ( 0.046)\tLoss 9.5078e-04 (6.4377e-02)\tAcc@1 100.00 ( 98.17)\n",
      "[200/466]\tTime  0.044 ( 0.045)\tLoss 1.3104e-01 (7.4692e-02)\tAcc@1  95.00 ( 97.54)\n",
      "[300/466]\tTime  0.043 ( 0.044)\tLoss 1.3241e-01 (8.7034e-02)\tAcc@1  95.00 ( 97.19)\n",
      "[400/466]\tTime  0.044 ( 0.044)\tLoss 6.1183e-02 (9.2482e-02)\tAcc@1 100.00 ( 97.12)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84862eacf1034560951a2d835998969f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 94.600\n",
      "Epoch:  5\n",
      "[  0/466]\tTime  0.304 ( 0.304)\tLoss 1.7121e-01 (1.7121e-01)\tAcc@1  85.00 ( 85.00)\n",
      "[100/466]\tTime  0.048 ( 0.046)\tLoss 1.1815e-01 (6.8504e-02)\tAcc@1  95.00 ( 97.43)\n",
      "[200/466]\tTime  0.046 ( 0.045)\tLoss 1.6653e-02 (6.5047e-02)\tAcc@1 100.00 ( 97.76)\n",
      "[300/466]\tTime  0.046 ( 0.044)\tLoss 1.1540e-01 (7.5737e-02)\tAcc@1  95.00 ( 97.44)\n",
      "[400/466]\tTime  0.040 ( 0.044)\tLoss 1.6249e-02 (8.0571e-02)\tAcc@1 100.00 ( 97.27)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1aef1e688914ab1b08a1d804af37299",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 97.000\n",
      "Epoch:  6\n",
      "[  0/466]\tTime  0.293 ( 0.293)\tLoss 2.0243e-01 (2.0243e-01)\tAcc@1  90.00 ( 90.00)\n",
      "[100/466]\tTime  0.045 ( 0.046)\tLoss 2.1786e-01 (7.4438e-02)\tAcc@1  95.00 ( 97.92)\n",
      "[200/466]\tTime  0.042 ( 0.045)\tLoss 8.6450e-03 (7.5136e-02)\tAcc@1 100.00 ( 97.51)\n",
      "[300/466]\tTime  0.043 ( 0.044)\tLoss 2.1358e-02 (6.5498e-02)\tAcc@1 100.00 ( 97.86)\n",
      "[400/466]\tTime  0.043 ( 0.044)\tLoss 4.0037e-02 (5.9442e-02)\tAcc@1 100.00 ( 98.08)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27f991e15b724c8784034a578e6fdea1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 97.800\n",
      "Epoch:  7\n",
      "[  0/466]\tTime  0.289 ( 0.289)\tLoss 7.3678e-03 (7.3678e-03)\tAcc@1 100.00 (100.00)\n",
      "[100/466]\tTime  0.044 ( 0.046)\tLoss 3.9803e-02 (4.6332e-02)\tAcc@1  95.00 ( 97.97)\n",
      "[200/466]\tTime  0.043 ( 0.045)\tLoss 6.6214e-04 (3.4386e-02)\tAcc@1 100.00 ( 98.63)\n",
      "[300/466]\tTime  0.044 ( 0.044)\tLoss 9.3362e-03 (3.3074e-02)\tAcc@1 100.00 ( 98.75)\n",
      "[400/466]\tTime  0.042 ( 0.044)\tLoss 8.4006e-03 (4.1172e-02)\tAcc@1 100.00 ( 98.48)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c2e77ccbd6c48839d16f8272a8c92ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 98.200\n",
      "Epoch:  8\n",
      "[  0/466]\tTime  0.300 ( 0.300)\tLoss 5.4997e-02 (5.4997e-02)\tAcc@1  95.00 ( 95.00)\n",
      "[100/466]\tTime  0.043 ( 0.046)\tLoss 2.2293e-02 (4.9052e-02)\tAcc@1 100.00 ( 98.42)\n",
      "[200/466]\tTime  0.042 ( 0.045)\tLoss 5.4809e-02 (5.5094e-02)\tAcc@1  95.00 ( 98.06)\n",
      "[300/466]\tTime  0.044 ( 0.044)\tLoss 3.8192e-03 (5.2097e-02)\tAcc@1 100.00 ( 98.16)\n",
      "[400/466]\tTime  0.044 ( 0.044)\tLoss 2.8451e-04 (4.7325e-02)\tAcc@1 100.00 ( 98.32)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c41b3a78ac5458aa48354207d800d2f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 98.400\n",
      "Epoch:  9\n",
      "[  0/466]\tTime  0.299 ( 0.299)\tLoss 1.4325e-03 (1.4325e-03)\tAcc@1 100.00 (100.00)\n",
      "[100/466]\tTime  0.044 ( 0.046)\tLoss 1.4930e-02 (2.6899e-02)\tAcc@1 100.00 ( 99.11)\n",
      "[200/466]\tTime  0.043 ( 0.045)\tLoss 1.2194e-02 (4.1103e-02)\tAcc@1 100.00 ( 98.73)\n",
      "[300/466]\tTime  0.044 ( 0.044)\tLoss 2.5995e-02 (5.4468e-02)\tAcc@1 100.00 ( 98.41)\n",
      "[400/466]\tTime  0.048 ( 0.044)\tLoss 8.7019e-03 (5.1813e-02)\tAcc@1 100.00 ( 98.45)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92931f0bbe4d4faabe2c8dd5a616bfaf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 98.600\n"
     ]
    }
   ],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(train_path[:-500], train_label[:-500], \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.RandomHorizontalFlip(),\n",
    "                        transforms.RandomVerticalFlip(),\n",
    "                        transforms.ColorJitter(brightness=.5, hue=.3),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=20, shuffle=True, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(train_path[-500:], train_label[-500:], \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=20, shuffle=False, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 0.005)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)\n",
    "best_acc = 0.0\n",
    "for epoch in range(10):\n",
    "    scheduler.step()\n",
    "    print('Epoch: ', epoch)\n",
    "\n",
    "    train(train_loader, model, criterion, optimizer, epoch)\n",
    "    val_acc = validate(val_loader, model, criterion)\n",
    "    \n",
    "    if val_acc.avg.item() > best_acc:\n",
    "        best_acc = round(val_acc.avg.item(), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "41e86d06-e364-4e3f-8631-c2f3c0722e35",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2510832/2676678360.py:81: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8dad4ee28c214b91afe816986f8f477a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(test_path, [0] * len(test_path), \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "val_label = pd.DataFrame()\n",
    "val_label['ImageID'] = [x.split('/')[-1] for x in test_path]\n",
    "val_label['label'] = predict(test_loader, model, 1).argmax(1)\n",
    "val_label.to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "5dc9d50b-9959-4cd7-acfb-91644ccddac6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ImageID</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>005d3487-92a5-48f3-9784-379f1217825d.png</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>00a6a262-1643-443f-b8ed-7fc075ac3c1e.png</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>00f7f66d-96f6-4c9c-b476-dced05e7211f.png</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>01a4038f-159e-4a73-bc65-2d20b3cc2a2d.png</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>01cf52dd-4235-403b-9696-00490042dcbd.png</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>963</th>\n",
       "      <td>fe028821-9968-473c-8d47-836ff3e6699a.png</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>964</th>\n",
       "      <td>fe41a271-2bba-40aa-b153-14a54be4d8a1.png</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>965</th>\n",
       "      <td>feaee038-cbed-4fbe-a7ab-1ab491807a7c.png</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>966</th>\n",
       "      <td>fef7660d-2c03-40da-b0c3-7bfc916d4a56.png</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>967</th>\n",
       "      <td>ff38d59e-9a11-41e4-901b-67097bb0e960.png</td>\n",
       "      <td>33</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>968 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                      ImageID  label\n",
       "0    005d3487-92a5-48f3-9784-379f1217825d.png      2\n",
       "1    00a6a262-1643-443f-b8ed-7fc075ac3c1e.png     20\n",
       "2    00f7f66d-96f6-4c9c-b476-dced05e7211f.png      6\n",
       "3    01a4038f-159e-4a73-bc65-2d20b3cc2a2d.png     17\n",
       "4    01cf52dd-4235-403b-9696-00490042dcbd.png      4\n",
       "..                                        ...    ...\n",
       "963  fe028821-9968-473c-8d47-836ff3e6699a.png     14\n",
       "964  fe41a271-2bba-40aa-b153-14a54be4d8a1.png     21\n",
       "965  feaee038-cbed-4fbe-a7ab-1ab491807a7c.png     13\n",
       "966  fef7660d-2c03-40da-b0c3-7bfc916d4a56.png     18\n",
       "967  ff38d59e-9a11-41e4-901b-67097bb0e960.png     33\n",
       "\n",
       "[968 rows x 2 columns]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f350b9a1-82a7-4051-8492-3dc762d8aaab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
